import os
import sys
import numpy as np
import random

import seaborn as sns
import matplotlib.pyplot as plt

# from phone_booth_collab_maze import *

# Integer ids of the discrete actions; a policy is indexed by these ids.
(LEFT, RIGHT, UP, DOWN, NOOP, HINT_UP, HINT_DOWN) = range(7)
ACTIONS = [LEFT, RIGHT, UP, DOWN, NOOP, HINT_UP, HINT_DOWN]

class ReceiverBeliefModel():
    """
    Belief model on the receiver's location for OBL.

    The belief is a (1, N) row vector over ``possible_pos``: the corridor
    cells ``(x, 0)`` for ``x`` in ``range(env.lengths[1])`` followed by the
    two exit cells ``(length - 1, -1)`` and ``(length - 1, 1)``.  One call to
    ``update_belief`` right-multiplies the belief by ``trans_mat``.

    The rows of ``trans_mat`` that belong to the exit cells are left at zero,
    so probability mass that reaches an exit is dropped by the next update.
    The belief is therefore not guaranteed to sum to one, which is why
    ``sample_belief`` normalises before sampling.
    """
    def __init__(self, policy, env):
        """
        Build the position index, the transition matrix and the initial belief.

        policy: sequence indexable by the action ids LEFT, RIGHT, UP, DOWN
            and NOOP, giving the probability of each action.
        env: object providing ``agent1_loc`` (start position) and
            ``lengths`` (index 1 is the receiver corridor length).
        """
        self.start_pos = env.agent1_loc
        self.length = env.lengths[1]
        last_x = self.length - 1

        # Get all possible positions to create the pos:index pairs of the
        # transition matrix.  Corridor cells come first, so cell (x, 0) has
        # index x.
        self.possible_pos = [(x, 0) for x in range(self.length)]
        # Exit locations
        self.possible_pos.append((last_x, -1))
        self.possible_pos.append((last_x, 1))

        # Create index
        self.pos_index_dict = {pos: idx for idx, pos in enumerate(self.possible_pos)}
        self.index_pos_dict = dict(enumerate(self.possible_pos))

        # transition matrix
        num_pos = len(self.possible_pos)
        self.trans_mat = np.zeros((num_pos, num_pos), dtype="float64")
        for x in range(self.length):
            here = (x, 0)
            at_exit_column = x == last_x
            targets = {
                # Horizontal moves are clamped at both ends of the corridor.
                LEFT: (max(x - 1, 0), 0),
                RIGHT: (min(x + 1, last_x), 0),
                # The exits can only be entered from the last corridor cell.
                UP: (x, 1) if at_exit_column else here,
                DOWN: (x, -1) if at_exit_column else here,
                NOOP: here,
            }
            # Accumulate instead of assigning: several actions can lead to the
            # same cell (e.g. a corridor of length 1 is both ends at once) and
            # none of their probabilities may be overwritten.
            for action, target in targets.items():
                self.trans_mat[x, self.pos_index_dict[target]] += policy[action]

        self.pos_indices = list(range(num_pos))
        # Initial belief
        self.reset_belief()

    def update_belief(self):
        """Advance the belief by one step of the transition matrix."""
        self.belief = np.matmul(self.belief, self.trans_mat)

    def sample_belief(self):
        """
        Sample a position from the normalised belief.

        Returns ``(prob, pos)`` where ``pos`` is the sampled position and
        ``prob`` is its entry in the stored (un-normalised) belief.

        Raises ValueError if the belief holds no probability mass.
        """
        total = self.belief.sum()
        if not total > 0:
            raise ValueError("cannot sample from a belief with no probability mass")
        weights = self.belief.reshape(-1)
        # Need normalization
        sampled_pos_index = np.random.choice(self.pos_indices, replace = False, p = weights / total)
        prob = weights[sampled_pos_index]
        return prob, self.index_pos_dict[sampled_pos_index]

    def reset_belief(self):
        """Put all belief mass back on the start position."""
        self.belief = np.zeros((1, len(self.possible_pos)), dtype="float64")
        self.belief[0, self.pos_index_dict[tuple(self.start_pos)]] = 1.0

class SenderBeliefModel():
    """
    Belief model on the sender's location for OBL.

    The belief is a (1, N) row vector over ``possible_pos``: the corridor
    cells ``(x, 0)`` for ``x`` in ``range(env.lengths[0])``, any coordinates
    in ``env.booth_coors`` (when the env has that attribute) and the decoy
    booths shifted into the agent's coordinate scheme.  ``update_belief``
    either advances the belief with ``trans_mat`` or, when a communication
    token is received, collapses it onto the correct booth.
    """
    def __init__(self, policy, env):
        """
        Build the position index, the transition matrix and the initial belief.

        policy: sequence indexable by all seven action ids (LEFT, RIGHT, UP,
            DOWN, NOOP, HINT_UP, HINT_DOWN), giving each action's probability.
        env: object providing ``agent0_loc``, ``lengths`` (index 0 is the
            sender corridor length), ``num_sender_decoy_booths`` and, when
            that count is positive, ``decoy_booth_locs``; ``booth_coors`` is
            optional.
        """
        # Get all possible positions to create the pos:index pairs of the
        # transition matrix.
        self.start_pos = env.agent0_loc
        self.length = env.lengths[0]
        self.possible_pos = []
        self.decoy_booth_pos_x = []
        # based on agent's coordinate scheme
        if env.num_sender_decoy_booths > 0:
            self.a_based_decoy_booth_locs = [(d_loc[0] - 1, d_loc[1]) for d_loc in env.decoy_booth_locs]
        else:
            self.a_based_decoy_booth_locs = []
        if hasattr(env, 'booth_coors'):
            self.possible_pos.extend(env.booth_coors)
            # ``correct_booth_coor`` is deliberately not set here (same as
            # before); update_belief falls back to this list instead of
            # failing with AttributeError.
            # NOTE(review): spreading the belief over all booth_coors mirrors
            # MultiPBSenderBeliefModel - confirm this is the intended meaning.
            self.correct_booth_coors = list(env.booth_coors)
        else:
            self.correct_booth_coor = (self.length - 1, 0)
            self.correct_booth_coors = [self.correct_booth_coor]
        self.possible_pos.extend((x, 0) for x in range(self.length))
        # Drop duplicates (a booth coordinate may also be a corridor cell).
        self.possible_pos = list(set(self.possible_pos))
        for ad_loc in self.a_based_decoy_booth_locs:
            self.possible_pos.append(ad_loc)
            self.decoy_booth_pos_x.append(ad_loc[0])

        # Create index
        self.pos_index_dict = {pos: idx for idx, pos in enumerate(self.possible_pos)}
        self.index_pos_dict = dict(enumerate(self.possible_pos))

        # Off-corridor cells that can be entered: only the decoy booths.
        self._enterable_booths = set(self.a_based_decoy_booth_locs)

        # transition matrix
        num_pos = len(self.possible_pos)
        self.trans_mat = np.zeros((num_pos, num_pos))
        all_actions = (LEFT, RIGHT, UP, DOWN, NOOP, HINT_UP, HINT_DOWN)
        for row, pos in enumerate(self.possible_pos):
            for action in all_actions:
                target = self._next_pos(pos, action)
                # A blocked or non-moving action keeps the agent in this row's
                # own cell; probabilities are accumulated, never overwritten.
                col = row if target == pos else self.pos_index_dict[target]
                self.trans_mat[row, col] += policy[action]

        self.pos_indices = list(range(num_pos))
        # Initial belief
        self.reset_belief()

    def _next_pos(self, pos, action):
        """Return the cell reached from ``pos`` by ``action`` (``pos`` itself if the agent stays)."""
        x, y = pos[0], pos[1]
        if action == LEFT or action == RIGHT:
            new_x = x - 1 if action == LEFT else x + 1
            if new_x < 0 or new_x > self.length - 1:
                # End of the corridor
                return pos
            if y == 0:
                return (new_x, 0)
            # Off the corridor: sideways only onto an adjacent booth
            # (decoy buttons could be next to each other).
            side = (new_x, y)
            return side if side in self._enterable_booths else pos
        if action == UP or action == DOWN:
            dy = 1 if action == UP else -1
            if y == 0:
                booth = (x, dy)
                return booth if booth in self._enterable_booths else pos
            # Off the corridor: only the step back towards y == 0 moves.
            if (y > 0) == (dy < 0):
                return (x, y + dy)
            return pos
        # NOOP, HINT_UP and HINT_DOWN do not move the agent.
        return pos

    def update_belief(self, comm_token = 0):
        """
        Advance the belief by one step, or collapse it on a communication token.

        comm_token: when non-zero, the belief is reset to the correct booth(s)
            instead of being multiplied by the transition matrix.
        """
        # If the receiver receives a communication token, it should know that the sender is in the correct phone booth
        if comm_token != 0:
            if hasattr(self, 'correct_booth_coor'):
                booths = [self.correct_booth_coor]
            else:
                booths = self.correct_booth_coors
            self.belief = np.zeros((1, len(self.possible_pos)))
            for coor in booths:
                self.belief[0, self.pos_index_dict[coor]] = 1.0 / len(booths)
        else:
            self.belief = np.matmul(self.belief, self.trans_mat)

    def sample_belief(self):
        """
        Sample a position from the normalised belief.

        Returns ``(prob, pos)`` where ``pos`` is the sampled position and
        ``prob`` is its entry in the stored belief (``1.0`` when there is
        only a single possible position).

        Raises ValueError if the belief holds no probability mass.
        """
        total = self.belief.sum()
        if not total > 0:
            raise ValueError("cannot sample from a belief with no probability mass")
        # reshape keeps a 1-D vector even when there is a single position.
        weights = self.belief.reshape(-1)
        # Need normalization
        sampled_pos_index = np.random.choice(self.pos_indices, replace = False, p = weights / total)
        prob = weights[sampled_pos_index] if weights.size > 1 else 1.0
        return prob, self.index_pos_dict[sampled_pos_index]

    def reset_belief(self):
        """Put all belief mass back on the start position."""
        self.belief = np.zeros((1, len(self.possible_pos)))
        self.belief[0, self.pos_index_dict[tuple(self.start_pos)]] = 1.0

class MultiPBSenderBeliefModel():
    """
    Belief model on the sender's location for OBL
    """
    def __init__(self, policy, env):
        # Get all possibe positions to create  pos:index in transition matrix value pair
        self.start_pos = env.agent0_loc
        self.length = env.lengths[0]
        self.possible_pos = []
        self.decoy_booth_pos_x = []
        self.non_main_correct_booth_pos_x  = []
        # based on agent's coordinate scheme
        if(env.num_sender_decoy_booths > 0):
            self.a_based_decoy_booth_locs = [(d_loc[0] - 1, d_loc[1]) for d_loc in env.decoy_booth_locs]
        else:
            self.a_based_decoy_booth_locs = []
        for i in range(self.length):
            self.possible_pos.append((i, 0))
        self.correct_booth_coors = []
        for coor in env.booth_coors:
            self.correct_booth_coors.append(coor)
            if(coor[1] != 0):
                self.possible_pos.append(coor)
                self.non_main_correct_booth_pos_x.append(coor[0])
        for ad_loc in self.a_based_decoy_booth_locs:
            self.possible_pos.append(ad_loc)
            self.decoy_booth_pos_x.append(ad_loc[0])

        # Create index
        self.pos_index_dict = {}
        self.index_pos_dict = {}
        for i in range(len(self.possible_pos)):
            self.pos_index_dict[self.possible_pos[i]] = i
            self.index_pos_dict[i] = self.possible_pos[i]

        # transition matrix
        self.trans_mat = np.zeros((len(self.possible_pos), len(self.possible_pos)))
        for i, pos in enumerate(self.possible_pos):
            left_pos = pos[0] - 1 if pos[0] > 0 else 0
            right_pos = pos[0] + 1 if pos[0] < self.length - 1 else self.length - 1
            if(pos[1] != 0):
                # Decoy buttons *decoy buttons could be next to each other
                if(pos[1] == 1):
                    # Up
                    self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] - 1)]] = policy[DOWN]
                    if(left_pos == pos[0]):
                        # Left End of corridor
                        self.trans_mat[i, i] = policy[LEFT] + policy[UP] + policy[NOOP]
                        if((right_pos, pos[1]) in self.a_based_decoy_booth_locs or (right_pos, pos[1]) in self.correct_booth_coors):
                            # Another decoy booth on the right
                            self.trans_mat[i, self.pos_index_dict[(right_pos, pos[1])]] = policy[RIGHT]
                        else:
                            # No booth on the right
                            self.trans_mat[i, i] +=  policy[RIGHT]
                    elif(right_pos == pos[0]):
                        # Right end of corridor
                        self.trans_mat[i, i] = policy[RIGHT] + policy[UP] + policy[NOOP]
                        if((left_pos, pos[1]) in self.a_based_decoy_booth_locs or (right_pos, pos[1]) in self.correct_booth_coors):
                            # Another decoy booth on the left
                            self.trans_mat[i, self.pos_index_dict[(left_pos, pos[1])]] = policy[LEFT]
                        else:
                            # No booth on the left
                            self.trans_mat[i, i] +=  policy[LEFT]
                    else:
                        # Not on ends
                        if(((left_pos, pos[1]) in self.a_based_decoy_booth_locs or (left_pos, pos[1]) in self.correct_booth_coors) and ((right_pos, pos[1]) in self.a_based_decoy_booth_locs) or (right_pos, pos[1]) in self.correct_booth_coors):
                            # Booth on left and right
                            self.trans_mat[i, self.pos_index_dict[(left_pos, pos[1])]] = policy[LEFT]
                            self.trans_mat[i, self.pos_index_dict[(right_pos, pos[1])]] = policy[RIGHT]
                            self.trans_mat[i, i] = policy[UP] + policy[NOOP]
                        elif((left_pos, pos[1]) in self.a_based_decoy_booth_locs or (left_pos, pos[1]) in self.correct_booth_coors):
                            # Booth on the left
                            self.trans_mat[i, self.pos_index_dict[(left_pos, pos[1])]] = policy[LEFT]
                            self.trans_mat[i, i] = policy[UP] + policy[NOOP] + policy[RIGHT]
                        elif((right_pos, pos[1]) in self.a_based_decoy_booth_locs or (right_pos, pos[1]) in self.correct_booth_coors):
                            # Booth on the right
                            self.trans_mat[i, self.pos_index_dict[(right_pos, pos[1])]] = policy[RIGHT]
                            self.trans_mat[i, i] = policy[UP] + policy[NOOP] + policy[LEFT]
                        else:
                            # No booths on left and right
                            self.trans_mat[i, i] = policy[UP] + policy[NOOP] + policy[LEFT] + policy[RIGHT]


                else:
                    # Down
                    self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] + 1)]] = policy[UP]
                    if(left_pos == pos[0]):
                        # Left End of corridor
                        self.trans_mat[i, i] = policy[LEFT] + policy[DOWN] + policy[NOOP]
                        if((right_pos, pos[1]) in self.a_based_decoy_booth_locs or (right_pos, pos[1]) in self.correct_booth_coors):
                            # Another decoy booth on the right
                            self.trans_mat[i, self.pos_index_dict[(right_pos, pos[1])]] = policy[RIGHT]
                        else:
                            # No booth on the right
                            self.trans_mat[i, i] +=  policy[RIGHT]
                    elif(right_pos == pos[0]):
                        # Right end of corridor
                        self.trans_mat[i, i] = policy[RIGHT] + policy[DOWN] + policy[NOOP]
                        if((left_pos, pos[1]) in self.a_based_decoy_booth_locs or (left_pos, pos[1]) in self.correct_booth_coors):
                            # Another decoy booth on the left
                            self.trans_mat[i, self.pos_index_dict[(left_pos, pos[1])]] = policy[LEFT]
                        else:
                            # No booth on the left
                            self.trans_mat[i, i] +=  policy[LEFT]
                    else:
                        # Not on ends
                        if(((left_pos, pos[1]) in self.a_based_decoy_booth_locs or (left_pos, pos[1]) in self.correct_booth_coors) and ((right_pos, pos[1]) in self.a_based_decoy_booth_locs or (right_pos, pos[1]) in self.correct_booth_coors)):
                            # Booth on left and right
                            self.trans_mat[i, self.pos_index_dict[(left_pos, pos[1])]] = policy[LEFT]
                            self.trans_mat[i, self.pos_index_dict[(right_pos, pos[1])]] = policy[RIGHT]
                            self.trans_mat[i, i] = policy[DOWN] + policy[NOOP]
                        elif((left_pos, pos[1]) in self.a_based_decoy_booth_locs or (left_pos, pos[1]) in self.correct_booth_coors):
                            # Booth on the left
                            self.trans_mat[i, self.pos_index_dict[left_pos]] = policy[LEFT]
                            self.trans_mat[i, i] = policy[DOWN] + policy[NOOP] + policy[RIGHT]
                        elif((right_pos, pos[1]) in self.a_based_decoy_booth_locs or (right_pos, pos[1]) in self.correct_booth_coors):
                            # Booth on the right
                            self.trans_mat[i, self.pos_index_dict[(right_pos, pos[1])]] = policy[RIGHT]
                            self.trans_mat[i, i] = policy[DOWN] + policy[NOOP] + policy[LEFT]
                        else:
                            # No booths on left and right
                            self.trans_mat[i, i] = policy[DOWN] + policy[NOOP] + policy[LEFT] + policy[RIGHT]
            else:
                if(pos[0] in self.decoy_booth_pos_x or pos[0] in self.non_main_correct_booth_pos_x):
                    # same y as decoy buttons
                    possible_decoy_up = (pos[0], 1)
                    possible_decoy_down = (pos[0], -1)
                    if(left_pos == pos[0]):
                        # Left End of corridor
                        self.trans_mat[i, self.pos_index_dict[(right_pos, pos[1])]] = policy[RIGHT]
                        if((possible_decoy_up in self.a_based_decoy_booth_locs or possible_decoy_up in self.correct_booth_coors) and (possible_decoy_down in self.a_based_decoy_booth_locs or possible_decoy_down in self.correct_booth_coors)):
                            # Up and down have decoy buttons
                            self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] + 1)]] = policy[UP]
                            self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] - 1)]] = policy[DOWN]
                            self.trans_mat[i, i] = policy[LEFT] + policy[NOOP]
                        elif(possible_decoy_up in self.a_based_decoy_booth_locs or possible_decoy_up in self.correct_booth_coors):
                            # Only Up
                            self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] + 1)]] = policy[UP]
                            self.trans_mat[i, i] = policy[LEFT] + policy[NOOP] + policy[DOWN]
                        else:
                            # Only Down
                            self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] - 1)]] = policy[DOWN]
                            self.trans_mat[i, i] = policy[LEFT] + policy[NOOP] + policy[UP]
                    elif(right_pos == pos[0]):
                        self.trans_mat[i, self.pos_index_dict[(left_pos, pos[1])]] = policy[LEFT]
                        if((possible_decoy_up in self.a_based_decoy_booth_locs or possible_decoy_up in self.correct_booth_coors) and (possible_decoy_down in self.a_based_decoy_booth_locs or possible_decoy_down in self.correct_booth_coors)):
                            # Up and down have decoy buttons
                            self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] + 1)]] = policy[UP]
                            self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] - 1)]] = policy[DOWN]
                            self.trans_mat[i, i] = policy[RIGHT] + policy[NOOP]
                        elif(possible_decoy_up in self.a_based_decoy_booth_locs or possible_decoy_up in self.correct_booth_coors):
                            # Only Up
                            self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] + 1)]] = policy[UP]
                            self.trans_mat[i, i] = policy[RIGHT] + policy[NOOP] + policy[DOWN]
                        else:
                            # Only Down
                            self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] - 1)]] = policy[DOWN]
                            self.trans_mat[i, i] = policy[RIGHT] + policy[NOOP] + policy[UP]
                    else:
                        if((possible_decoy_up in self.a_based_decoy_booth_locs or possible_decoy_up in self.correct_booth_coors) and (possible_decoy_down in self.a_based_decoy_booth_locs or possible_decoy_down in self.correct_booth_coors)):
                            # Up and down have decoy buttons
                            self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] + 1)]] = policy[UP]
                            self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] - 1)]] = policy[DOWN]
                            self.trans_mat[i, self.pos_index_dict[(pos[0] - 1, pos[1])]] = policy[LEFT]
                            self.trans_mat[i, self.pos_index_dict[(pos[0] + 1, pos[1])]] = policy[RIGHT]
                            self.trans_mat[i, i] = policy[NOOP]
                        elif(possible_decoy_up in self.a_based_decoy_booth_locs or possible_decoy_up in self.correct_booth_coors):
                            # Only Up
                            self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] + 1)]] = policy[UP]
                            self.trans_mat[i, self.pos_index_dict[(pos[0] - 1, pos[1])]] = policy[LEFT]
                            self.trans_mat[i, self.pos_index_dict[(pos[0] + 1, pos[1])]] = policy[RIGHT]
                            self.trans_mat[i, i] = policy[NOOP] + policy[DOWN]
                        else:
                            # Only Down
                            self.trans_mat[i, self.pos_index_dict[(pos[0], pos[1] - 1)]] = policy[DOWN]
                            self.trans_mat[i, self.pos_index_dict[(pos[0] - 1, pos[1])]] = policy[LEFT]
                            self.trans_mat[i, self.pos_index_dict[(pos[0] + 1, pos[1])]] = policy[RIGHT]
                            self.trans_mat[i, i] = policy[NOOP] + policy[UP]
                else:
                    # Only left and right
                    if(left_pos == pos[0]):
                        # Left End of corridor
                        self.trans_mat[i, self.pos_index_dict[(right_pos, pos[1])]] = policy[RIGHT]
                        self.trans_mat[i, i] = policy[LEFT] + policy[UP] + policy[DOWN] + policy[NOOP]
                    elif(right_pos == pos[0]):
                        # Right End of corridor
                            self.trans_mat[i, self.pos_index_dict[(left_pos, pos[1])]] = policy[LEFT]
                            self.trans_mat[i, i] = policy[RIGHT] + policy[UP] + policy[DOWN] + policy[NOOP]
                    else:
                        self.trans_mat[i, self.pos_index_dict[(left_pos, pos[1])]] = policy[LEFT]
                        self.trans_mat[i, self.pos_index_dict[(right_pos, pos[1])]] = policy[RIGHT]
                        self.trans_mat[i, i] = policy[UP] + policy[DOWN] + policy[NOOP]
            self.trans_mat[i, i] += policy[HINT_UP] + policy[HINT_DOWN]

        # Initial belief
        self.belief = np.zeros((1, len(self.possible_pos)))
        self.belief[0, self.pos_index_dict[tuple(self.start_pos)]] = 1.0
        self.pos_indices = [i for i in range(len(self.possible_pos))]

    def update_belief(self, comm_token = 0):
        # If the receiver receives a communication token, it should know that the sender is in the correct phone booth
        if(comm_token != 0):
            self.belief =  np.zeros((1, len(self.possible_pos)))
            for coor in self.correct_booth_coors:
                self.belief[0, self.pos_index_dict[coor]] = 1.0 / len(self.correct_booth_coors)
        else:
            self.belief = np.matmul(self.belief, self.trans_mat)

    def sample_belief(self):
        """
        Draw one position from the current belief.

        The belief is normalised by its sum only for the purpose of sampling.
        Returns ``(prob, pos)`` where ``pos`` is the sampled coordinate and
        ``prob`` is the stored (un-normalised) belief entry for it, or ``1.0``
        when the belief has a single column.
        """
        has_many_states = self.belief.shape[1] > 1

        # Sampling weights must sum to one, so divide by the total mass.
        if has_many_states:
            weights = self.belief.squeeze() / self.belief.sum()
        else:
            weights = [(self.belief.item()) / self.belief.sum()]

        drawn_index = np.random.choice(self.pos_indices, replace = False, p = weights)

        prob = self.belief.squeeze()[drawn_index] if has_many_states else 1.0
        return prob, self.index_pos_dict[drawn_index]

    def reset_belief(self):
        """
        Restore the belief to a one-hot row vector on the start position.
        """
        n_states = len(self.possible_pos)
        self.belief = np.zeros((1, n_states))
        start_column = self.pos_index_dict[tuple(self.start_pos)]
        self.belief[0, start_column] = 1.0


"""
Testing belief models
"""

if __name__ == "__main__":
    # Demo: build the environment and both belief models, then plot the
    # sender belief as a heatmap for four consecutive propagation steps.
    d = {
        "lengths":(10,3),
        "starts": (5,1),
        "receiver_booth_loc":0,
        "booth_loc":9,
        "episode_limit": 40,
        "right_r": 1.0,
        "wrong_r": -0.5,
        "num_sender_decoy_booths": 2,
        "decoy_booths_fixed": 1,
        "use_intermediate_reward": False,
        "use_mi_shaping": False,
        "use_mi_loss": False
    }
    np.random.seed(1)
    random.seed(1)
    # NOTE(review): `PBCMaze` is not imported in the visible part of this file
    # (the `phone_booth_collab_maze` import at the top is commented out) --
    # confirm it is in scope before running this script.
    env = PBCMaze(env_args=d)
    env.reset()
    receiver_pi_0 = [0.2, 0.2, 0.2, 0.2, 0.2]
    sender_pi_0 = [1/7, 1/7, 1/7, 1/7, 1/7, 1/7, 1/7]
    # rb_model is not plotted; constructing it still exercises its constructor.
    rb_model = ReceiverBeliefModel(receiver_pi_0, env)
    sb_model = SenderBeliefModel(sender_pi_0, env)

    f, axes = plt.subplots(2, 2, sharex=True, sharey=True)
    # One colour bar shared by all four panels.
    cbar_ax = f.add_axes([.91, .3, .03, .4])

    # axes.flat walks the 2x2 grid row by row: [0,0], [0,1], [1,0], [1,1].
    for step, ax in enumerate(axes.flat, start=1):
        sender_possible_pos = sb_model.index_pos_dict

        # Lay the belief out on a 3-row grid: row = y + 1, column = x.
        sender_heatmap_values = np.zeros((3, env.lengths[0]))
        for s_k, s_v in sender_possible_pos.items():
            sender_heatmap_values[s_v[1] + 1, s_v[0]] = sb_model.belief[0, s_k]
        print(sender_heatmap_values)

        # Hide every grid cell that is not a possible sender position.
        # A set gives constant-time membership instead of scanning .values().
        valid_cells = set(sender_possible_pos.values())
        mask = np.zeros_like(sender_heatmap_values)
        for i in range(mask.shape[0]):
            for j in range(mask.shape[1]):
                if (j, i - 1) not in valid_cells:
                    mask[i, j] = True

        g = sns.heatmap(
            sender_heatmap_values,
            mask=mask,
            linewidth=2.0,
            annot=True,
            ax=ax,
            square=True,
            annot_kws={"fontsize":8},
            cbar=True,
            cbar_ax=cbar_ax,
            vmin=0,
            vmax=1,
        )
        g.set(xticklabels=[], yticklabels=[])
        g.set(xlabel=None)
        g.tick_params(bottom=False, left=False)
        g.set(title="Step = {}".format(step))

        # Propagate the belief one step before drawing the next panel.
        sb_model.update_belief()

    plt.show()
